-
Notifications
You must be signed in to change notification settings - Fork 606
[RFC] Add TrainingModule and SGD JNI + PTE-only Training Workflow #12247
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/12247
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 New Failure, 11 PendingAs of commit b4d7ada with merge base ba19c75 ( NEW FAILURE - The following job has failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This PR needs a
|
a6e15a7
to
c008409
Compare
c008409
to
aba87ed
Compare
@georgehong has imported this pull request. If you are a Meta employee, you can view this in D77939473. |
facebook::jni::alias_ref<TensorHybrid::javaobject> jtensor); | ||
}; | ||
|
||
class JEValue : public facebook::jni::JavaClass<JEValue> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm does this not already exist for inference?
* @param nesterov Whether to use Nesterov momentum | ||
* @return new {@link org.pytorch.executorch.SGD} object | ||
*/ | ||
public static SGD create( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Doesnt have to be this diff but would it be more "java-y" to have builder classes?
new SGDBuilder().learning_rate().buildSGD();
} | ||
|
||
@DoNotStrip | ||
private native EValue[] executeForwardBackwardNative(String methodName, EValue... inputs); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
noob q. What are these "native" apis?
As title, adds wrappers together with unit test based on XOR train.cpp example.
aba87ed
to
b4d7ada
Compare
Summary
Adds JNI for SGD and TrainingModule, including a unit test that mirrors train.cpp for a simple XOR example. Also makes the following change:
android_test_setup.sh
to match the pushd-popd directory movement for consistency and flexibility. This is also used to fix errors with generating the XOR files.Training dependencies are already enabled for Java JNI library, so we skip adding additional guard flags.
Test plan
Updated XOR tests that check .pte only convergence workflow.
For the XOR tests, the device logs will show convergence values: